# Credit: most of the model-training code is adapted from the PyTorch tutorial at
# https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

import numpy as np
import pickle as pkl
import os, sys

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# ToTensor, then Normalize with mean 0.5 / std 0.5 on the single channel,
# i.e. each value x becomes (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# FashionMNIST train split (used for training) and test split (used for
# validation), stored under ./data and downloaded there if missing.
training_set, validation_set = (
    torchvision.datasets.FashionMNIST(
        './data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Command line: RUN_TYPE RUN_ID [DROP_INDEX]
#   RUN_TYPE == 'subbagging' -> train on a random half of the training set
#   RUN_TYPE == 'loo'        -> train on all samples except DROP_INDEX
#   any other RUN_TYPE       -> train on the full training set
# Outputs for the run are written to the directory RUN_TYPE/RUN_ID.
if len(sys.argv) < 3:
    sys.exit('usage: {} RUN_TYPE RUN_ID [DROP_INDEX]'.format(sys.argv[0]))
run_type = sys.argv[1]
run_id = sys.argv[2]
run_path = os.path.join(run_type, run_id)
# makedirs also creates the RUN_TYPE parent on the first run, and exist_ok
# avoids the check-then-create race when several runs start concurrently.
os.makedirs(run_path, exist_ok=True)

n = len(training_set)
m = n // 2  # subsample size for subbagging

# Class labels (read before training_set may be replaced by a Subset,
# which does not expose .classes)
classes = training_set.classes
K = len(classes)

if run_type == 'subbagging':
    # m distinct indices drawn without replacement (global RNG, unseeded).
    inds = np.random.choice(n, m, replace=False)
    training_set = torch.utils.data.Subset(training_set, inds)
elif run_type == 'loo':
    if len(sys.argv) < 4:
        sys.exit('usage: {} loo RUN_ID DROP_INDEX'.format(sys.argv[0]))
    drop_ind = int(sys.argv[3])
    inds = np.delete(np.arange(n), drop_ind)
    training_set = torch.utils.data.Subset(training_set, inds)
else:
    inds = np.arange(n)

# Persist the training-set indices this run was trained on.
with open(os.path.join(run_path, 'indices.pkl'), 'wb') as f:
    pkl.dump(inds, f)

# Batches of 4. The training order is reshuffled every epoch; the validation
# order must stay fixed (shuffle=False) because the probability matrix
# written at the end of this script is filled by position in this loader.
training_loader, validation_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=shuffle)
    for dataset, shuffle in ((training_set, True), (validation_set, False))
)

import torch.nn as nn
import torch.nn.functional as F

# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling reduce a
    (B, 1, 28, 28) input to (B, 16, 4, 4). Three fully connected layers
    then map the flattened 256 features to class logits.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (B, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))  # (B, 6, 12, 12) for 28x28 input
        x = self.pool(F.relu(self.conv2(x)))  # (B, 16, 4, 4) for 28x28 input
        # Flatten each sample separately. With view(-1, 256), a wrongly
        # sized input could silently be regrouped across samples; here it
        # fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Network, objective and optimizer shared by the training loop below.
model = GarmentClassifier()

# Cross-entropy on the model's raw (un-softmaxed) fc3 outputs.
loss_fn = nn.CrossEntropyLoss()

# Plain SGD with momentum from torch.optim.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

def train_one_epoch(epoch_index, tb_writer, report_every=1000):
    """Run one SGD pass over ``training_loader`` and log the training loss.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and
    ``training_loader``.

    Args:
        epoch_index: zero-based epoch number; sets the TensorBoard x-axis offset.
        tb_writer: TensorBoard ``SummaryWriter`` that receives 'Loss/train'.
        report_every: number of batches averaged into each printed/logged loss.

    Returns:
        The mean per-batch loss of the last complete window of
        ``report_every`` batches. If the epoch had fewer batches than that,
        the mean over all of its batches is returned instead (0.0 when the
        loader is empty).
    """
    running_loss = 0.
    last_loss = None
    epoch_loss = 0.
    n_batches = 0

    for i, (inputs, labels) in enumerate(training_loader):
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Forward pass, loss, and gradients
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        n_batches += 1
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    if last_loss is None:
        # No complete reporting window: report the whole-epoch mean rather
        # than a misleading 0.0.
        last_loss = epoch_loss / n_batches if n_batches else 0.
    return last_loss

# One TensorBoard log directory per invocation, named by the start time.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Mean of the per-batch validation losses. Accumulating Python floats
    # (.item()) keeps the printed/logged value a number rather than a tensor,
    # and counting batches explicitly avoids relying on a leaked loop index
    # (which is undefined when the loader is empty).
    running_vloss = 0.0
    n_vbatches = 0
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            voutputs = model(vinputs)
            running_vloss += loss_fn(voutputs, vlabels).item()
            n_vbatches += 1

    avg_vloss = running_vloss / max(n_vbatches, 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)
    writer.flush()

    # Save a checkpoint whenever validation loss improves on the best so far.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = os.path.join(
            run_path,
            'model_{}_{}'.format(timestamp, epoch_number))
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

# Training is done; release the TensorBoard event file.
writer.close()

# Softmax class probabilities for every validation sample, one row per sample
# in dataset order, written to RUN_PATH/weights.pkl.
# NOTE(review): this uses the model as it stands after the final epoch, not
# the best checkpoint saved during training; confirm that is intended.
N = len(validation_set)
w = np.zeros((N, K))
model.eval()
start = 0
with torch.no_grad():
    for vinputs, _ in validation_loader:
        probs = F.softmax(model(vinputs), dim=1).numpy()
        # Fill rows by running offset so any batch size (including a short
        # final batch) lines up; row order equals dataset order only because
        # validation_loader uses shuffle=False.
        w[start:start + len(probs)] = probs
        start += len(probs)
# Renormalize in float64 so each row sums to 1 despite float32 softmax rounding.
w = w / w.sum(1)[:, np.newaxis]

with open(os.path.join(run_path, 'weights.pkl'), 'wb') as f:
    pkl.dump(w, f)